package com.example.mq.mqserver;

import com.example.mq.common.*;
import com.example.mq.mqserver.core.BasicProperties;
import lombok.extern.slf4j.Slf4j;

import java.io.*;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

//这个类来表示 消息队列的服务器
//本质上是一个 TCP服务器
@Slf4j
public class BrokerServer {
    //表示当前服务器持有的虚拟主机,在此只实现
    private VirtualHost virtualHost = new VirtualHost("default");
    //用这个哈希表表示所有会话关系
    //key 为 channelId, value 为 socket对象
    private ConcurrentHashMap<String, Socket> sessions = new ConcurrentHashMap<>();
    //引入一个线程池,处理多个用户的请求
    private ExecutorService executorService = null;
    //控制服务器是否继续运行
    private volatile boolean runnable = true;
    //表示服务的socket
    private ServerSocket serverSocket = null;

    //指定服务器的指定端口
    public BrokerServer(int port) throws IOException {
        serverSocket = new ServerSocket(port);
    }

    //启动服务器
    public void start() throws IOException {
        log.info("[BrokerServer] 服务器启动");
        executorService = Executors.newCachedThreadPool();
        try{
            while(runnable){
                Socket clientSocket = serverSocket.accept();
                executorService.submit(()->{
                    processConnection(clientSocket);
                });
            }
        } catch (SocketException e) {
            log.info("[BrokerServer] 服务器停止运行");
        }
    }

    //关闭服务器
    public void stop() throws IOException {
        runnable = false;
        //立即结束线程池中所有任务
        executorService.shutdown();
        serverSocket.close();
    }

    //处理客户端连接
    //在一次连接中,可能会需要处理多次请求和响应,使用while循环
    private void processConnection(Socket clientSocket) {
        try(InputStream inputStream = clientSocket.getInputStream();
            OutputStream outputStream = clientSocket.getOutputStream()) {
            //读取二进制的请求
            try (DataInputStream dataInputStream = new DataInputStream(inputStream);
                 DataOutputStream dataOutputStream = new DataOutputStream(outputStream)) {
                while (true) {
                    //1.读取请求并解析
                    Request request = readRequest(dataInputStream);
                    //2.根据请求计算响应
                    Response response = process(request, clientSocket);
                    //3.把响应写回客户端
                    writeResponse(dataOutputStream, response);
                }
            }
        } catch (EOFException | SocketException e){
            //DataInputStream 读取到末尾就会返回EOFException
            log.info("[BrokerServer] connection 关闭 客户端的地址: " + clientSocket.getInetAddress().toString()
                    + ":" + clientSocket.getPort());
        } catch (IOException | ClassNotFoundException | MqException e) {
            log.error("[BrokerServer] connection 出现异常!");
        } finally {
            try {
                clientSocket.close();
                //清理当前信道
                clearClosedSession(clientSocket);
            } catch (IOException e) {
                e.printStackTrace();
            }

        }
    }

    private void clearClosedSession(Socket clientSocket) {
        // 这里要做的事情, 主要就是遍历上述 sessions hash 表, 把该被关闭的 socket 对应的键值对, 统统删掉.
        List<String> toDeleteChannelId = new ArrayList<>();
        for (Map.Entry<String, Socket> entry : sessions.entrySet()) {
            if (entry.getValue() == clientSocket) {
                //集合类不能在遍历的时候删除
                toDeleteChannelId.add(entry.getKey());
            }
        }
        for (String channelId : toDeleteChannelId) {
            sessions.remove(channelId);
        }

    }

    private void writeResponse(DataOutputStream dataOutputStream, Response response) throws IOException {
        dataOutputStream.writeInt(response.getType());
        dataOutputStream.writeInt(response.getLength());
        dataOutputStream.write(response.getPayload());
        // 这个刷新缓冲区也是重要的操作!!
        dataOutputStream.flush();
    }

    private Response process(Request request, Socket clientSocket) throws IOException, ClassNotFoundException, MqException {
        BasicArguments basicArguments = (BasicArguments) BinaryTool.fromBytes(request.getPayload());
        log.info("[Request] rid=" + basicArguments.getRid() + ", channelId=" + basicArguments.getChannelId()
                + ", type=" + request.getType() + ", length=" + request.getLength());
        //根据约定的type,来区分这次请求操作
        boolean ok = true;
        if(request.getType() == 0x01){
            //创建channel
            sessions.put(basicArguments.getChannelId(), clientSocket);
        }else if(request.getType() == 0x02){
            //删除channel
            sessions.remove(basicArguments.getChannelId());
        }else if(request.getType() == 0x03){
            //创建交换机
            ExchangeDeclareArguments arguments = (ExchangeDeclareArguments) basicArguments;
            ok = virtualHost.exchangeDeclare(arguments.getExchangeName(), arguments.getExchangeType(),
                    arguments.isDurable(), arguments.isAutoDelete(), arguments.getArguments());
        }else if (request.getType() == 0x4) {
            //删除交换机
            ExchangeDeleteArguments arguments = (ExchangeDeleteArguments) basicArguments;
            ok = virtualHost.exchangeDelete(arguments.getExchangeName());
        } else if (request.getType() == 0x5) {
            //创建队列
            QueueDeclareArguments arguments = (QueueDeclareArguments) basicArguments;
            ok = virtualHost.queueDeclare(arguments.getQueueName(), arguments.isDurable(),
                    arguments.isExclusive(), arguments.isAutoDelete(), arguments.getArguments());
        } else if (request.getType() == 0x6) {
            //删除队列
            QueueDeleteArguments arguments = (QueueDeleteArguments) basicArguments;
            ok = virtualHost.queueDelete((arguments.getQueueName()));
        } else if (request.getType() == 0x7) {
            //创建绑定
            QueueBindArguments arguments = (QueueBindArguments) basicArguments;
            ok = virtualHost.queueBind(arguments.getQueueName(), arguments.getExchangeName(), arguments.getBindingKey());
        } else if (request.getType() == 0x8) {
            //解除绑定
            QueueUnbindArguments arguments = (QueueUnbindArguments) basicArguments;
            ok = virtualHost.queueUnbind(arguments.getQueueName(), arguments.getExchangeName());
        }else if (request.getType() == 0x9) {
            //发布消息
            BasicPublishArguments arguments = (BasicPublishArguments) basicArguments;
            ok = virtualHost.basicPublish(arguments.getExchangeName(), arguments.getRoutingKey(),
                    arguments.getBasicProperties(), arguments.getBody());
        }else if(request.getType() == 0xa){
            BasicConsumeArguments arguments = (BasicConsumeArguments) basicArguments;
            ok = virtualHost.basicConsume(arguments.getConsumeTag(), arguments.getQueueName(), arguments.isAutoAck(),
                    new Consumer() {
                        //这个回调函数,是让服务器把收到的消息推送给消费者
                        @Override
                        public void handleDelivery(String consumerTag, BasicProperties basicProperties, byte[] body) throws MqException, IOException {
                            Socket clientSocket = sessions.get(consumerTag);
                            if (clientSocket == null || clientSocket.isClosed()) {
                                throw new MqException("[BrokerServer] 订阅消息的客户端已经关闭!");
                            }
                            //构造响应的数据
                            SubscribeReturns subscribeReturns = new SubscribeReturns();
                            subscribeReturns.setChannelId(consumerTag);
                            subscribeReturns.setRid(""); // 由于这里只有响应, 没有请求, 不需要去对应. rid 暂时不需要.
                            subscribeReturns.setOk(true);
                            subscribeReturns.setConsumerTag(consumerTag);
                            subscribeReturns.setBasicProperties(basicProperties);
                            subscribeReturns.setBody(body);
                            byte[] payload = BinaryTool.toBytes(subscribeReturns);
                            Response response = new Response();
                            // 0xc 表示服务器给消费者客户端推送的消息数据.
                            response.setType(0xc);
                            // response 的 payload 就是一个 SubScribeReturns
                            response.setLength(payload.length);
                            response.setPayload(payload);
                            //此处dataOutputStream 这个对象不能 close
                            //如果 把 dataOutputStream 关闭, 就会直接把 clientSocket 里的 outputStream 也关了
                            //后续无法继续网socket中写数据了
                            DataOutputStream dataOutputStream = new DataOutputStream(clientSocket.getOutputStream());
                            writeResponse(dataOutputStream, response);
                        }
                    });
        }else if (request.getType() == 0xb) {
            // 调用 basicAck 确认消息.
            BasicAckArguments arguments = (BasicAckArguments) basicArguments;
            ok = virtualHost.basicAck(arguments.getQueueName(), arguments.getMessageId());
        } else {
            // 当前的 type 是非法的.
            throw new MqException("[BrokerServer] 未知的 type! type=" + request.getType());
        }
        // 构造响应
        BasicReturns basicReturns = new BasicReturns();
        basicReturns.setChannelId(basicArguments.getChannelId());
        basicReturns.setRid(basicArguments.getRid());
        basicReturns.setOk(ok);
        byte[] payload = BinaryTool.toBytes(basicReturns);

        Response response = new Response();
        response.setType(request.getType());
        response.setLength(payload.length);
        response.setPayload(payload);

        log.info("[Response] rid=" + basicReturns.getRid() + ", channelId=" + basicReturns.getChannelId()
                + ", type=" + response.getType() + ", length=" + response.getLength());
        return response;
    }

    private Request readRequest(DataInputStream dataInputStream) throws IOException {
        Request request = new Request();
        request.setType(dataInputStream.readInt());
        request.setLength(dataInputStream.readInt());
        byte[] payload = new byte[request.getLength()];
        int n = dataInputStream.read(payload);
        if (n != request.getLength()) {
            throw new IOException("读取请求格式出错!");
        }
        request.setPayload(payload);
        return request;
    }
}
